# -*- coding: utf-8 -*-
"""
Created on Mon Feb 16 11:05:21 2026

@author: Michael Palace
"""
import pandas as pd
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from networkx.algorithms import bipartite
from networkx.algorithms.community import greedy_modularity_communities
from math import sqrt

# -------------------------------------------------
# LOAD DATA
# -------------------------------------------------
# To analyse a real tab-separated export, replace the example below with:
# df = pd.read_csv("your_file.csv", sep="\t")

# Example structure (REMOVE if loading real file).
# Each tuple is one record, with values in the same order as _EXAMPLE_COLUMNS.
_EXAMPLE_COLUMNS = (
    "TreeNum", "Simulation Run", "Female Tree", "Male Tree",
    "X", "Y", "X2", "Y2", "Day Number",
)
_EXAMPLE_ROWS = (
    (300, 2, 15, 108, 1675, 18329, 6328, 18362, 1),
    (300, 2, 15, 140, 1675, 18329, 11093, 16663, 1),
    (300, 2, 15, 141, 1675, 18329, 1874, 7666, 1),
)
# Transpose the records into the column -> list-of-values mapping pandas expects.
data = {
    column: [record[position] for record in _EXAMPLE_ROWS]
    for position, column in enumerate(_EXAMPLE_COLUMNS)
}
df = pd.DataFrame.from_dict(data)

# -------------------------------------------------
# OPTIONAL: FILTER TO ONE SIMULATION
# -------------------------------------------------
# df = df[df["Simulation Run"] == 2]

# -------------------------------------------------
# CREATE UNIQUE NODE IDS
# -------------------------------------------------
# Prefix each tree number with its role, so that a female tree and a male
# tree sharing the same number become two distinct graph nodes.
for _id_column, _tree_column, _prefix in (
    ("Female_ID", "Female Tree", "F_"),
    ("Male_ID", "Male Tree", "M_"),
):
    df[_id_column] = df[_tree_column].astype(str).radd(_prefix)

# -------------------------------------------------
# BUILD BIPARTITE GRAPH
# -------------------------------------------------
B = nx.Graph()

# Tag each node with its side of the bipartition (0 = female, 1 = male).
# The node-set step below selects females by this attribute.
B.add_nodes_from(df["Female_ID"].unique(), bipartite=0)
B.add_nodes_from(df["Male_ID"].unique(), bipartite=1)

# Edge weight = number of rows recording the same female/male pair.
edge_weights = (
    df.groupby(["Female_ID", "Male_ID"])
      .size()
      .reset_index(name="weight")
)

# Bulk-insert (female, male, weight) triples. This replaces an iterrows()
# loop, which constructs a pandas Series for every row.
B.add_weighted_edges_from(
    edge_weights[["Female_ID", "Male_ID", "weight"]].itertuples(
        index=False, name=None
    ),
    weight="weight",
)

# -------------------------------------------------
# VERIFY BIPARTITE
# -------------------------------------------------
# Sanity check: stop unless the graph is two-colourable (bipartite), then
# report its size.
if nx.is_bipartite(B):
    print("Graph is bipartite")
else:
    raise ValueError("Graph is not bipartite. Check your data.")

for _label, _count in (
    ("Total nodes:", B.number_of_nodes()),
    ("Total edges:", B.number_of_edges()),
):
    print(_label, _count)

# -------------------------------------------------
# GET NODE SETS
# -------------------------------------------------
# Female nodes are those tagged bipartite=0 when the graph was built; every
# other node in B is treated as male.
female_nodes = set()
for _node, _attrs in B.nodes(data=True):
    if _attrs["bipartite"] == 0:
        female_nodes.add(_node)
male_nodes = set(B).difference(female_nodes)

# -------------------------------------------------
# 1) BIPARTITE CLUSTERING (FEMALES)
# -------------------------------------------------
# Per-node bipartite clustering, computed for the female side only,
# followed by its mean across those nodes.
bip_clust = bipartite.clustering(B, female_nodes)
avg_bip_clust = np.mean(np.fromiter(bip_clust.values(), dtype=float))
print("Average Bipartite Clustering:", avg_bip_clust)

# -------------------------------------------------
# 2) FEMALE PROJECTION
# -------------------------------------------------
# One node per female. Two females are linked when they share a male
# neighbour; with networkx's default (ratio=False), the edge "weight" is the
# number of shared male neighbours.
G_female = bipartite.weighted_projected_graph(B, female_nodes)

print("Projected female nodes:", G_female.number_of_nodes())
print("Projected female edges:", G_female.number_of_edges())

# -------------------------------------------------
# 3) GREEDY MODULARITY
# -------------------------------------------------
# Modularity is undefined when the projection carries no edge weight:
# networkx's modularity() divides by the total degree and raises
# ZeroDivisionError. This occurs, e.g., when only one female tree is present
# (as in the bundled example data), because the projection then has a single
# node and no edges. In that case, report NaN with singleton communities.
if G_female.size(weight="weight") == 0:
    communities = [{node} for node in G_female]
    modularity = float("nan")
    print("Female projection has no edges; modularity is undefined.")
else:
    communities = greedy_modularity_communities(G_female, weight="weight")
    modularity = nx.algorithms.community.modularity(
        G_female, communities, weight="weight"
    )

print("Modularity (female projection):", modularity)

# -------------------------------------------------
# 4) DISTANCE CALCULATION
# -------------------------------------------------
def euclidean(x1, y1, x2, y2):
    """Return the Euclidean distance between (x1, y1) and (x2, y2).

    Accepts scalars or equally-shaped NumPy arrays / pandas Series, in which
    case the distance is computed element-wise. ``np.hypot`` avoids the
    overflow that squaring and summing the differences explicitly can raise
    for very large coordinates.
    """
    return np.hypot(x2 - x1, y2 - y1)

# Distance between the two coordinate pairs on each row (presumably the
# female tree at (X, Y) and the male tree at (X2, Y2) — confirm with the data
# source). Computed with one vectorised NumPy call over whole columns rather
# than a Python-level, row-by-row DataFrame.apply.
df["Distance"] = np.hypot(df["X2"] - df["X"], df["Y2"] - df["Y"])

# -------------------------------------------------
# 5) DISTANCE VS INTERACTION SCATTER
# -------------------------------------------------
# One point per distinct (female, male, distance) combination, with
# y = number of interaction rows for it. Plotting every row at a constant
# y = 1 stacked repeated events on top of each other and left the y-axis
# carrying no information.
pair_counts = (
    df.groupby(["Female_ID", "Male_ID", "Distance"])
      .size()
      .reset_index(name="Interactions")
)

plt.figure()
plt.scatter(pair_counts["Distance"], pair_counts["Interactions"])
plt.xlabel("Distance")
plt.ylabel("Number of interactions")
plt.title("Distance vs Interaction")
plt.show()

# -------------------------------------------------
# 6) OPTIONAL: PUBLICATION QUALITY NETWORK PLOT
# -------------------------------------------------
plt.figure(figsize=(6, 6))

# Fixed seed so the layout is reproducible between runs.
pos = nx.spring_layout(B, seed=42)

# Draw each side of the bipartition in its own colour. With a single colour,
# the female/male structure of the network is not visible in the figure.
for side_nodes, side_color, side_label in (
    (female_nodes, "lightcoral", "Female"),
    (male_nodes, "lightskyblue", "Male"),
):
    nx.draw_networkx_nodes(
        B, pos,
        nodelist=sorted(side_nodes),
        node_color=side_color,
        edgecolors="black",
        linewidths=1,
        label=side_label,
    )

# Scale edge width by interaction count (the "weight" attribute) so that
# frequently repeated pairs stand out.
edge_list = list(B.edges(data="weight", default=1))
max_weight = max((w for _, _, w in edge_list), default=1)
nx.draw_networkx_edges(
    B, pos,
    edgelist=[(u, v) for u, v, _ in edge_list],
    width=[0.5 + 2.5 * w / max_weight for _, _, w in edge_list],
)

plt.legend(frameon=False, scatterpoints=1)
plt.axis("off")

plt.savefig("bipartite_network.png", dpi=1200, bbox_inches="tight")
plt.show()
